{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.10"
    },
    "colab": {
      "name": "18_Using_Reinforcement_Learning_to_Play_Pong.ipynb",
      "provenance": []
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m0jRtbRGsoZy",
        "colab_type": "text"
      },
      "source": [
        "# Tutorial Part 18: Using Reinforcement Learning to Play Pong\n",
        "\n",
        "This notebook demonstrates using reinforcement learning to train an agent to play Pong.\n",
        "\n",
        "The first step is to create an `Environment` that implements this task.  Fortunately,\n",
        "OpenAI Gym already provides an implementation of Pong (and many other tasks appropriate\n",
        "for reinforcement learning).  DeepChem's `GymEnvironment` class provides an easy way to\n",
        "use environments from OpenAI Gym.  We could just use it directly, but in this case we\n",
        "subclass it and preprocess the screen image a little bit to make learning easier.\n",
        "\n",
        "## Colab\n",
        "\n",
        "This tutorial and the rest in this sequence are designed to be done in Google colab. If you'd like to open this notebook in colab, you can use the following link.\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepchem/deepchem/blob/master/examples/tutorials/18_Using_Reinforcement_Learning_to_Play_Pong.ipynb)\n",
        "\n",
        "## Setup\n",
        "\n",
        "To run DeepChem within Colab, you'll need to run the following cell of installation commands. This will take about 5 minutes to run to completion and install your environment. To install `gym` you should also use `pip install 'gym[atari]'` (We need the extra modifier since we'll be using an atari game). We'll add this command onto our usual Colab installation commands for you"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qXdmcnhtst-z",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 170
        },
        "outputId": "5c7cf904-0f5c-41d8-c404-75258bafca86"
      },
      "source": [
        "!curl -Lo conda_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py\n",
        "import conda_installer\n",
        "conda_installer.install()\n",
        "!/root/miniconda/bin/conda info -e"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
            "                                 Dload  Upload   Total   Spent    Left  Speed\n",
            "\r  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0\r100  3489  100  3489    0     0  89461      0 --:--:-- --:--:-- --:--:-- 91815\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "add /root/miniconda/lib/python3.6/site-packages to PYTHONPATH\n",
            "all packages is already installed\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "# conda environments:\n",
            "#\n",
            "base                  *  /root/miniconda\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-1kpETs2GnbI",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 188
        },
        "outputId": "dc8d5ae6-a0d7-4236-8168-8b615806ce41"
      },
      "source": [
        "!pip install --pre deepchem\n",
        "import deepchem\n",
        "deepchem.__version__"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: deepchem in /usr/local/lib/python3.6/dist-packages (2.4.0rc1.dev20200805145259)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from deepchem) (0.16.0)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from deepchem) (1.4.1)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from deepchem) (1.18.5)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from deepchem) (1.0.5)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from deepchem) (0.22.2.post1)\n",
            "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->deepchem) (2.8.1)\n",
            "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->deepchem) (2018.9)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.6.1->pandas->deepchem) (1.15.0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'2.4.0-rc1.dev'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 2
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9sv6kX_VsoZ1",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 187
        },
        "outputId": "ce4206d5-7917-4cad-c716-238a41f78e2a"
      },
      "source": [
        "!pip install 'gym[atari]'"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: gym[atari] in /usr/local/lib/python3.6/dist-packages (0.17.2)\n",
            "Requirement already satisfied: cloudpickle<1.4.0,>=1.2.0 in /usr/local/lib/python3.6/dist-packages (from gym[atari]) (1.3.0)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from gym[atari]) (1.4.1)\n",
            "Requirement already satisfied: pyglet<=1.5.0,>=1.4.0 in /usr/local/lib/python3.6/dist-packages (from gym[atari]) (1.5.0)\n",
            "Requirement already satisfied: numpy>=1.10.4 in /usr/local/lib/python3.6/dist-packages (from gym[atari]) (1.18.5)\n",
            "Requirement already satisfied: Pillow; extra == \"atari\" in /usr/local/lib/python3.6/dist-packages (from gym[atari]) (7.0.0)\n",
            "Requirement already satisfied: opencv-python; extra == \"atari\" in /usr/local/lib/python3.6/dist-packages (from gym[atari]) (4.1.2.30)\n",
            "Requirement already satisfied: atari-py~=0.2.0; extra == \"atari\" in /usr/local/lib/python3.6/dist-packages (from gym[atari]) (0.2.6)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from pyglet<=1.5.0,>=1.4.0->gym[atari]) (0.16.0)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from atari-py~=0.2.0; extra == \"atari\"->gym[atari]) (1.15.0)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EuRrb3vpsoZ_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import deepchem as dc\n",
        "import numpy as np\n",
        "\n",
        "class PongEnv(dc.rl.GymEnvironment):\n",
        "  def __init__(self):\n",
        "    super(PongEnv, self).__init__('Pong-v0')\n",
        "    self._state_shape = (80, 80)\n",
        "  \n",
        "  @property\n",
        "  def state(self):\n",
        "    # Crop everything outside the play area, reduce the image size,\n",
        "    # and convert it to black and white.\n",
        "    cropped = np.array(self._state)[34:194, :, :]\n",
        "    reduced = cropped[0:-1:2, 0:-1:2]\n",
        "    grayscale = np.sum(reduced, axis=2)\n",
        "    bw = np.zeros(grayscale.shape)\n",
        "    bw[grayscale != 233] = 1\n",
        "    return bw\n",
        "\n",
        "  def __deepcopy__(self, memo):\n",
        "    return PongEnv()\n",
        "\n",
        "env = PongEnv()"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GNnO3MZ_soaG",
        "colab_type": "text"
      },
      "source": [
        "Next we create a network to implement the policy.  We begin with two convolutional layers to process\n",
        "the image.  That is followed by a dense (fully connected) layer to provide plenty of capacity for game\n",
        "logic.  We also add a small Gated Recurrent Unit.  That gives the network a little bit of memory, so\n",
        "it can keep track of which way the ball is moving.\n",
        "\n",
        "We concatenate the dense and GRU outputs together, and use them as inputs to two final layers that serve as the\n",
        "network's outputs.  One computes the action probabilities, and the other computes an estimate of the\n",
        "state value function.\n",
        "\n",
        "We also provide an input for the initial state of the GRU, and returned its final state at the end.  This is required by the learning algorithm"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BLdt8WAQsoaH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow.keras.layers import Input, Concatenate, Conv2D, Dense, Flatten, GRU, Reshape\n",
        "\n",
        "class PongPolicy(dc.rl.Policy):\n",
        "    def __init__(self):\n",
        "        super(PongPolicy, self).__init__(['action_prob', 'value', 'rnn_state'], [np.zeros(16)])\n",
        "\n",
        "    def create_model(self, **kwargs):\n",
        "        state = Input(shape=(80, 80))\n",
        "        rnn_state = Input(shape=(16,))\n",
        "        conv1 = Conv2D(16, kernel_size=8, strides=4, activation=tf.nn.relu)(Reshape((80, 80, 1))(state))\n",
        "        conv2 = Conv2D(32, kernel_size=4, strides=2, activation=tf.nn.relu)(conv1)\n",
        "        dense = Dense(256, activation=tf.nn.relu)(Flatten()(conv2))\n",
        "        gru, rnn_final_state = GRU(16, return_state=True, return_sequences=True)(\n",
        "            Reshape((-1, 256))(dense), initial_state=rnn_state)\n",
        "        concat = Concatenate()([dense, Reshape((16,))(gru)])\n",
        "        action_prob = Dense(env.n_actions, activation=tf.nn.softmax)(concat)\n",
        "        value = Dense(1)(concat)\n",
        "        return tf.keras.Model(inputs=[state, rnn_state], outputs=[action_prob, value, rnn_final_state])\n",
        "\n",
        "policy = PongPolicy()"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YU19h0aUsoaN",
        "colab_type": "text"
      },
      "source": [
        "We will optimize the policy using the Asynchronous Advantage Actor Critic (A3C) algorithm.  There are lots of hyperparameters we could specify at this point, but the default values for most of them work well on this problem.  The only one we need to customize is the learning rate."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "id": "Fw_wu511soaO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# from deepchem.models.optimizers import Adam\n",
        "# a3c = dc.rl.A3C(env, policy, model_dir='model', optimizer=Adam(learning_rate=0.0002))"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-PUD4JG2soaU",
        "colab_type": "text"
      },
      "source": [
        "Optimize for as long as you have patience to.  By 1 million steps you should see clear signs of learning.  Around 3 million steps it should start to occasionally beat the game's built in AI.  By 7 million steps it should be winning almost every time.  Running on my laptop, training takes about 20 minutes for every million steps."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Wa18EQlmsoaV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# # Change this to train as many steps as you have patience for.\n",
        "# a3c.fit(1000)"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_xHNjusSsoaa",
        "colab_type": "text"
      },
      "source": [
        "Let's watch it play and see how it does! "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ud6DB_ndsoab",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# # This code doesn't work well on Colab\n",
        "# env.reset()\n",
        "# while not env.terminated:\n",
        "#     env.env.render()\n",
        "#     env.step(a3c.select_action(env.state))"
      ],
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3MGK4nrhsoah",
        "colab_type": "text"
      },
      "source": [
        "# Congratulations! Time to join the Community!\n",
        "\n",
        "Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:\n",
        "\n",
        "## Star DeepChem on [GitHub](https://github.com/deepchem/deepchem)\n",
        "This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.\n",
        "\n",
        "## Join the DeepChem Gitter\n",
        "The DeepChem [Gitter](https://gitter.im/deepchem/Lobby) hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!"
      ]
    }
  ]
}